"""remove instructions index

Revision ID: 7cf3054cbbcc
Revises: b9e516e2d3b3
Create Date: 2025-02-09 15:31:00.772295

"""

import sqlalchemy as sa
from sqlalchemy import orm
from alembic import op
from mealie.db.models._model_utils.guid import GUID
from mealie.core.root_logger import get_logger

# revision identifiers, used by Alembic.
# `revision` is this migration's ID; `down_revision` is the migration it builds on.
revision: str = "7cf3054cbbcc"
down_revision: str | None = "b9e516e2d3b3"
branch_labels: str | tuple[str, ...] | None = None
depends_on: str | tuple[str, ...] | None = None

# Module logger; truncate_normalized_fields() uses it to report a model whose commit failed.
logger = get_logger()


class SqlAlchemyBase(orm.DeclarativeBase):
    """Declarative base used only by the throwaway mappings in this migration.

    Every ``DeclarativeBase`` subclass carries its own registry and ``MetaData``,
    so the classes below stay separate from any other declarative base.
    """

    @classmethod
    def normalized_fields(cls) -> list[orm.InstrumentedAttribute]:
        """Return the mapped normalized columns of this model (none on the base)."""
        no_columns: list[orm.InstrumentedAttribute] = list()
        return no_columns


class RecipeModel(SqlAlchemyBase):
    """Migration-local mapping of the ``recipes`` table.

    Only the primary key and the normalized text columns are declared here.
    """

    __tablename__ = "recipes"

    id: orm.Mapped[GUID] = orm.mapped_column(GUID, primary_key=True, default=GUID.generate)
    name_normalized: orm.Mapped[str] = orm.mapped_column(sa.String, nullable=False, index=True)
    description_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)

    @classmethod
    def normalized_fields(cls) -> list[orm.InstrumentedAttribute]:
        """Return the normalized columns that truncate_normalized_fields() shortens."""
        return [cls.name_normalized, cls.description_normalized]


class RecipeIngredientModel(SqlAlchemyBase):
    """Migration-local mapping of the ``recipes_ingredients`` table.

    Only the integer primary key and the normalized text columns are declared here.
    """

    __tablename__ = "recipes_ingredients"

    id: orm.Mapped[int] = orm.mapped_column(sa.Integer, primary_key=True)
    note_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)
    original_text_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)

    @classmethod
    def normalized_fields(cls) -> list[orm.InstrumentedAttribute]:
        """Return the normalized columns that truncate_normalized_fields() shortens."""
        return [cls.note_normalized, cls.original_text_normalized]


class IngredientFoodModel(SqlAlchemyBase):
    """Migration-local mapping of the ``ingredient_foods`` table.

    Only the primary key and the singular/plural normalized name columns are declared here.
    """

    __tablename__ = "ingredient_foods"
    id: orm.Mapped[GUID] = orm.mapped_column(GUID, primary_key=True, default=GUID.generate)
    name_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)
    plural_name_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)

    @classmethod
    def normalized_fields(cls) -> list[orm.InstrumentedAttribute]:
        """Return the normalized columns that truncate_normalized_fields() shortens."""
        return [cls.name_normalized, cls.plural_name_normalized]


class IngredientFoodAliasModel(SqlAlchemyBase):
    """Migration-local mapping of the ``ingredient_foods_aliases`` table.

    Only the primary key and the normalized name column are declared here.
    """

    __tablename__ = "ingredient_foods_aliases"
    id: orm.Mapped[GUID] = orm.mapped_column(GUID, primary_key=True, default=GUID.generate)
    name_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)

    @classmethod
    def normalized_fields(cls) -> list[orm.InstrumentedAttribute]:
        """Return the normalized columns that truncate_normalized_fields() shortens."""
        return [cls.name_normalized]


class IngredientUnitModel(SqlAlchemyBase):
    """Migration-local mapping of the ``ingredient_units`` table.

    Only the primary key and the four normalized name/abbreviation columns are declared here.
    """

    __tablename__ = "ingredient_units"
    id: orm.Mapped[GUID] = orm.mapped_column(GUID, primary_key=True, default=GUID.generate)
    name_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)
    plural_name_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)
    abbreviation_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)
    plural_abbreviation_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)

    @classmethod
    def normalized_fields(cls) -> list[orm.InstrumentedAttribute]:
        """Return the normalized columns that truncate_normalized_fields() shortens."""
        return [
            cls.name_normalized,
            cls.plural_name_normalized,
            cls.abbreviation_normalized,
            cls.plural_abbreviation_normalized,
        ]


class IngredientUnitAliasModel(SqlAlchemyBase):
    """Migration-local mapping of the ``ingredient_units_aliases`` table.

    Only the primary key and the normalized name column are declared here.
    """

    __tablename__ = "ingredient_units_aliases"
    id: orm.Mapped[GUID] = orm.mapped_column(GUID, primary_key=True, default=GUID.generate)
    name_normalized: orm.Mapped[str | None] = orm.mapped_column(sa.String, index=True)

    @classmethod
    def normalized_fields(cls) -> list[orm.InstrumentedAttribute]:
        """Return the normalized columns that truncate_normalized_fields() shortens."""
        return [cls.name_normalized]


def truncate_normalized_fields(max_length: int = 255) -> None:
    """Shorten every normalized text value longer than ``max_length`` characters.

    For each migration-local model, all rows are loaded through an ORM session
    bound to the migration's connection. Any non-empty value in one of the
    model's ``normalized_fields()`` that exceeds ``max_length`` is replaced by its
    first ``max_length`` characters.

    Work is committed once per model. A failed commit is logged and rolled back,
    and the remaining models are still processed (best-effort).

    Args:
        max_length: Number of characters to keep in each normalized value.
            Defaults to 255, the limit this migration was written for.
    """
    bind = op.get_bind()

    models: list[type[SqlAlchemyBase]] = [
        RecipeModel,
        RecipeIngredientModel,
        IngredientFoodModel,
        IngredientFoodAliasModel,
        IngredientUnitModel,
        IngredientUnitAliasModel,
    ]

    # The context manager guarantees the session is closed (the original never closed it).
    # NOTE(review): under SQLAlchemy 2.0's default join_transaction_mode
    # ("conservative_savepoint"), a Session bound to a Connection that is already
    # inside a transaction does not commit that outer transaction on commit(), but
    # rollback() may roll it back. Confirm this interaction with Alembic's migration
    # transaction is intended.
    with orm.Session(bind=bind) as session:
        for model in models:
            # Same column list for every row of this model; compute it once.
            fields = model.normalized_fields()

            for record in session.query(model).all():
                for field in fields:
                    value = getattr(record, field.key)
                    # None/empty values are left alone; only over-long values are rewritten.
                    if value and len(value) > max_length:
                        setattr(record, field.key, value[:max_length])

            try:
                session.commit()
            except Exception:
                logger.exception(f"Failed to truncate normalized fields for {model.__name__}")
                session.rollback()


def upgrade():
    """Drop the index on ``recipe_instructions.text``, then truncate normalized fields."""
    # batch_alter_table is Alembic's batch-mode wrapper, used for SQLite's limited ALTER support.
    with op.batch_alter_table("recipe_instructions", schema=None) as instructions_table:
        instructions_table.drop_index("ix_recipe_instructions_text")

    # The data fix-up runs only after the index is gone.
    truncate_normalized_fields()


def downgrade():
    """Recreate the non-unique index on ``recipe_instructions.text``.

    Only the index is restored. Values shortened by upgrade() are not brought back.
    """
    with op.batch_alter_table("recipe_instructions", schema=None) as instructions_table:
        instructions_table.create_index("ix_recipe_instructions_text", ["text"], unique=False)
